﻿#pragma kernel Project
#pragma kernel Apply

#include "MathUtils.cginc"
#include "AtomicDeltas.cginc"

// Per-constraint orientation index pairs: elements [2*i] and [2*i + 1] hold the
// two entries of `orientations` / `invRotationalMasses` used by constraint i.
StructuredBuffer<int> orientationIndices;
// Per-constraint, per-axis values that Project divides by deltaTime^2 to obtain
// the time-adjusted compliance.
// NOTE(review): despite the name, these are consumed as compliances - confirm with the CPU side.
StructuredBuffer<float3> stiffnesses;
// Per-constraint plasticity parameters:
//   x - threshold; compared (squared) against the squared length of the Darboux delta's xyz part.
//   y - rate; scales (together with deltaTime) how much of the delta is added to the rest vector.
StructuredBuffer<float2> plasticity;
// Per-constraint rest Darboux vector. Read by Project and rewritten when the plasticity threshold is exceeded.
RWStructuredBuffer<quaternion> restDarboux;
// Per-constraint multipliers, accumulated by Project (one float3 per constraint).
RWStructuredBuffer<float3> lambdas;

// Orientations addressed through orientationIndices. Read by Project; handed to
// ApplyOrientationDelta in Apply.
RWStructuredBuffer<quaternion> orientations;
// Weights (w1, w2) looked up with the same indices as `orientations`.
StructuredBuffer<float> invRotationalMasses;

// Variables set from the CPU
// Threads whose index is >= this value return without doing any work.
uint activeConstraintCount;
// Time step used for the compliance scaling and for the plastic rest-vector update.
// NOTE(review): units are not visible here - presumably seconds; confirm.
float deltaTime;
// Forwarded unchanged to ApplyOrientationDelta.
// NOTE(review): presumably a successive over-relaxation factor - confirm in AtomicDeltas.cginc.
float sorFactor;

// Project kernel: one thread per constraint.
// For constraint id.x it measures how far the relative rotation between its two
// orientations is from the stored rest Darboux vector, optionally lets the rest
// vector creep (plasticity), computes the multiplier change and accumulates
// orientation deltas for both orientations through AddOrientationDelta.
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID) 
{
    uint c = id.x;

    // Threads beyond the active range have no constraint to solve.
    if (c >= activeConstraintCount)
        return;

    // The two orientation indices of a constraint are stored back to back.
    int q1 = orientationIndices[c * 2];
    int q2 = orientationIndices[c * 2 + 1];

    float w1 = invRotationalMasses[q1];
    float w2 = invRotationalMasses[q2];

    // Compliance adjusted by the squared time step.
    float3 alpha = stiffnesses[c] / (deltaTime * deltaTime);

    // Rest Darboux vector and current relative rotation: conj(q1) * q2.
    quaternion restVector = restDarboux[c];
    quaternion current = qmul(q_conj(orientations[q1]), orientations[q2]);

    // Two candidate differences (current + rest, current - rest);
    // the one with the smaller squared norm is used.
    quaternion deltaPlus = current + restVector;
    quaternion deltaMinus = current - restVector;

    quaternion delta = deltaMinus;
    if (dot(deltaMinus, deltaMinus) > dot(deltaPlus, deltaPlus))
        delta = deltaPlus;

    // Plasticity: once the vector part of the difference exceeds the threshold (x),
    // move the rest vector towards the current configuration at rate (y) and store it.
    float2 plasticParams = plasticity[c];
    if (dot(delta.xyz, delta.xyz) > plasticParams.x * plasticParams.x)
    {
        restVector += delta * plasticParams.y * deltaTime;
        restDarboux[c] = restVector;
    }

    // Multiplier change, per axis; EPSILON keeps the denominator away from zero.
    float3 dlambda = (delta.xyz - alpha * lambdas[c]) / (alpha + w1 + w2 + EPSILON);

    // Lift the multiplier change to a quaternion whose scalar part is zero.
    quaternion dlambdaQ = quaternion(dlambda.x, dlambda.y, dlambda.z, 0);

    // Accumulate weighted, opposite-signed corrections for the two orientations.
    AddOrientationDelta(q1, qmul(orientations[q2], dlambdaQ) * w1);
    AddOrientationDelta(q2, -qmul(orientations[q1], dlambdaQ) * w2);

    lambdas[c] += dlambda;
}

// Apply kernel: one thread per constraint.
// Calls ApplyOrientationDelta, with sorFactor, on both orientations referenced
// by constraint id.x.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID) 
{
    uint c = id.x;

    // Only threads that map to an active constraint do any work.
    if (c < activeConstraintCount)
    {
        // The two orientation indices of a constraint are stored back to back.
        int first = orientationIndices[c * 2];
        int second = orientationIndices[c * 2 + 1];

        ApplyOrientationDelta(orientations, first, sorFactor);
        ApplyOrientationDelta(orientations, second, sorFactor);
    }
}